/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
 * Description: Implementation of DataflowAttr consistency checking and InferContext data-flow attribute helpers
 */

#include "framework/graph/core/infershape/op_ir_ctx.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/op/data_flow_attr_defs.h"
#include "framework/graph/utils/tensor_utils.h"
#include "framework/graph/debug/ge_log.h"
#include "framework/graph/utils/attr_utils.h"

using namespace std;

namespace ge {
bool DataflowAttr::CheckAttr(const DataflowAttr& dataflowAttr)
{
    if (dtype_ != dataflowAttr.dtype_) {
        FMK_LOGE("dtype check fail, attr dtype:%d is not equal %d", dataflowAttr.dtype_, dtype_);
        return false;
    }
    if (shape_.size() != dataflowAttr.shape_.size()) {
        FMK_LOGW(
            "shape size check fail, attr shape size:%zu is not equal %zu", dataflowAttr.shape_.size(), shape_.size());
        return false;
    }
    for (size_t i = 0; i < dataflowAttr.shape_.size(); ++i) {
        if (shape_.at(i) != dataflowAttr.shape_.at(i)) {
            FMK_LOGE("shape value check fail, attr shape[%zu]:%ld is not equal %ld", i, dataflowAttr.shape_.at(i),
                shape_.at(i));
            return false;
        }
    }
    return true;
}

/*
 * Registers dataflowAttr under handle, replacing any entry already stored for that handle.
 *
 * When an entry already exists and CheckAttr reports that it differs from dataflowAttr,
 * a warning is logged and the new shape and data type are written as attributes onto the
 * op desc obtained from the existing entry, before the entry itself is replaced.
 *
 * Returns GRAPH_SUCCESS once the map has been updated.
 * NOTE(review): GE_CHK_BOOL_RET_STATUS presumably logs and returns GRAPH_FAILED when the
 * wrapped call yields false, in which case the map is left untouched — confirm against the macro.
 */
GraphErrCodeStatus InferContext::SetDataFlowAttr(const std::string& handle, const DataflowAttr& dataflowAttr)
{
    auto existing = dataflowAttrMap_.find(handle);
    const bool conflictsWithExisting =
        (existing != dataflowAttrMap_.end()) && !existing->second.CheckAttr(dataflowAttr);
    if (conflictsWithExisting) {
        FMK_LOGW("data flow attr is set before");
        auto& opDesc = existing->second.GetOpDesc();
        GE_CHK_BOOL_RET_STATUS(
            ge::AttrUtils::SetListInt(opDesc, hiai::ATTR_NAME_ELEMENT_SHAPE, dataflowAttr.GetShape()), GRAPH_FAILED,
            "failed to set %s shape", opDesc.GetName().c_str());
        GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetInt(opDesc, hiai::ATTR_NAME_DTYPE, dataflowAttr.GetDataType()),
            GRAPH_FAILED, "failed to set %s type", opDesc.GetName().c_str());
    }

    // Insert a new entry or overwrite the existing one.
    dataflowAttrMap_[handle] = dataflowAttr;
    return GRAPH_SUCCESS;
}

/*
 * Copies the attribute stored under handle into dataflowAttr.
 * Returns GRAPH_SUCCESS when an entry exists; otherwise returns GRAPH_FAILED and
 * leaves dataflowAttr unmodified.
 */
GraphErrCodeStatus InferContext::GetDataFlowAttr(const std::string& handle, DataflowAttr& dataflowAttr) const
{
    const auto entry = dataflowAttrMap_.find(handle);
    if (entry == dataflowAttrMap_.end()) {
        return GRAPH_FAILED;
    }
    dataflowAttr = entry->second;
    return GRAPH_SUCCESS;
}

/*
 * Reports whether an attribute has been stored under handle.
 */
bool InferContext::IsDataFlowAttrExist(const std::string& handle) const
{
    return dataflowAttrMap_.find(handle) != dataflowAttrMap_.end();
}

/*
 * Verifies that the operator has exactly expectSize inputs.
 * On a mismatch, an error message naming the actual size, the expected size, and the
 * operator's name and type is recorded through AddVerifyErrMsg and GRAPH_FAILED is returned.
 * Returns GRAPH_SUCCESS when the sizes agree.
 */
GraphErrCodeStatus InferContext::VerifyInputSize(size_t expectSize)
{
    OpDesc& opDesc = opIRFacade_.GetOpDesc();
    const size_t actualSize = opIRFacade_.GetInputsSize();
    if (actualSize == expectSize) {
        return GRAPH_SUCCESS;
    }

    // Assemble the diagnostic piece by piece.
    std::string errMsg("Input size ");
    errMsg += std::to_string(actualSize);
    errMsg += " is wrong, should be ";
    errMsg += std::to_string(expectSize);
    errMsg += ", name : ";
    errMsg += opDesc.GetName();
    errMsg += ", type : ";
    errMsg += opDesc.GetType();
    errMsg += ".";
    AddVerifyErrMsg(errMsg);
    return GRAPH_FAILED;
}
} // namespace ge
